import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from PIL import Image
import os


# 加载resnet18模型
resnet18 = models.resnet18(pretrained=False)
resnet101 = models.resnet101(pretrained=True)
densenet201 = models.densenet201(pretrained=True)
squeezenet_1_1 = models.squeezenet1_1(pretrained=True)

# 获取resnet18最后一层输出，输出为512维,最后一层本来是用作 分类的，原始网络分为1000类
# 用 softmax函数或者 fully connected 函数，但是用 nn.identtiy() 函数把最后一层替换掉，相当于得到分类之前的特征！
# Identity模块，它将输入直接传递给输出，而不会对输入进行任何变换。
resnet18.fc = nn.Identity()
resnet101.fc = nn.Identity()
densenet201.fc = nn.Identity()


# 构建新的网络，将resnet18的输出作为输入
# 构建新的网络，将resnet18的输出作为输入
class Attention(nn.Module):
    def __init__(self, in_channels):
        super(Attention, self).__init__()
        # 输入及输出都为3通道，不改变原始图片通道数
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        a = F.relu(self.conv(x))
        a = F.softmax(a.view(a.size(0), -1), dim=1).view_as(a)
        x = x * a
        return x


# 构建新的网络，将resnet18的输出作为输入
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 注意力，用于区分输入图片重要的部分
        self.attention1 = Attention(3)
        self.resnet18 = resnet18
        self.resnet101 = resnet101
        self.densenet201 = densenet201
        self.fc1 = nn.Linear(1000, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        self.fc4 = nn.Linear(10, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x = self.attention1(x)
        # x = self.resnet18(x)
        # x = self.resnet101(x)
        x = self.densenet201(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.softmax(x)
        x = x.view(-1, 2)
        return x


# 实例化网络
model = Net()
# 将模型放入GPU
model = model.cuda()

# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器，添加l2正则化
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0)


# 加载数据集
# 创建一个transform对象
def rgb2bgr(image):
    image = np.array(image)[:, :, ::-1]
    image = Image.fromarray(np.uint8(image))
    return image


transform_list = [
    # transforms.Resize((224, 224)),
    # #transforms.ColorJitter的参数主要有：brightness，contrast，saturation和hue。
    # # brightness：用于调整图像的亮度，取值范围为[0, 1]，如果设置为0.5，则图像的亮度会随机增加或减少0-50%。
    # # contrast：用于调整图像的对比度，取值范围为[0, 1]，如果设置为0.5，则图像的对比度会随机增加或减少0-50%。
    # # saturation：用于调整图像的饱和度，取值范围为[0, 1]，如果设置为0.5，则图像的饱和度会随机增加或减少0-50%。
    # # hue：用于调整图像的色调，取值范围为[-0.5, 0.5]，如果设置为0.5，则图像的色调会随机增加或减少0-50%。
    #  transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=1),
    # # 原始图像在旋转-90度到90度也能被识别
    transforms.RandomRotation(degrees=(-45, 45)),
    # # 抵抗模糊的影响:参数3是模糊程度的参数，它表示高斯滤波器的核的大小
    transforms.GaussianBlur(3),
    # # size：表示裁剪出来的图像的大小，即输出图像的大小。
    # # scale：表示裁剪框的大小与原始图像的比例，范围是0.08到1.0，0.08表示原始图像的最小尺寸，1.0表示原始图像的最大尺寸。
    # # ratio：表示裁剪框的宽高比。
    # # interpolation：表示插值方法，2表示双线性插值。
    # torchvision.transforms.RandomResizedCrop((112, 112), scale=(0.8, 1), ratio=(1, 1), interpolation=2),
    # #0.5的概率水平翻转
    transforms.RandomHorizontalFlip(p=0.5),
    # degrees：可从中选择的度数范围。如果为非零数字，旋转角度从（-degrees,+degress),或者可设置为（min,max)
    # translate:水平和垂直平移的最大绝对偏移量。长度为2的元组，数值在(0,1)之间，dx在（-w*a,w*a)，dy在（-h*b,h*b)
    # scale:比例因子区间，例如（a，b），则从范围a<=比例<=b中随机采样比例。默认情况下，将保留原始比例。
    # shear:可从中选择的度数范围。放射变换的角度，若为 (a, b)，x 轴在 (-a, a) 之间随机选择错切角度，y 轴在 (-b, b) 之间随机选择错切角度,用灰色填充
    transforms.RandomAffine(degrees=(-45, 45), translate=(0.1, 0.3), scale=(0.8, 1.2), shear=(10, 10),
                            fill=(114, 114, 114))
]

transformR = transforms.Compose([transforms.RandomApply(transform_list, 0.1)])
transform = transforms.Compose([
    transformR,
    transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5),
    # rgb转bgr
    torchvision.transforms.Lambda(rgb2bgr),
    torchvision.transforms.Resize((64, 64)),
    # 入的图片为PIL image 或者 numpy.nadrry格式的图片，其shape为（HxWxC）数值范围在[0,255],转换之后shape为（CxHxw）,数值范围在[0,1]
    transforms.ToTensor(),
    # 进行归一化和标准化，Imagenet数据集的均值和方差为：mean=(0.485, 0.456, 0.406)，std=(0.229, 0.224, 0.225)，
    # 因为这是在百万张图像上计算而得的，所以我们通常见到在训练过程中使用它们做标准化。
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]),
    # #这行代码表示使用transforms.RandomErasing函数，以概率p=1，在图像上随机选择一个尺寸为scale=(0.02, 0.33)，长宽比为ratio=(1, 1)的区域，
    # #进行随机像素值的遮盖，只能对tensor操作：
    # transforms.RandomErasing(p=0.1, scale=(0.02, 0.2), ratio=(1, 1), value='random')
])
transform2 = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5),
    # rgb转bgr
    torchvision.transforms.Lambda(rgb2bgr),
    # 入的图片为PIL image 或者 numpy.nadrry格式的图片，其shape为（HxWxC）数值范围在[0,255],转换之后shape为（CxHxw）,数值范围在[0,1]
    transforms.ToTensor(),
    # 进行归一化和标准化，Imagenet数据集的均值和方差为：mean=(0.485, 0.456, 0.406)，std=(0.229, 0.224, 0.225)，
    # 因为这是在百万张图像上计算而得的，所以我们通常见到在训练过程中使用它们做标准化。
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
])
print("开始读取数据集")
train_dataset = torchvision.datasets.ImageFolder(r'./data/train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
test_dataset = torchvision.datasets.ImageFolder(r'./data/test', transform=transform2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=True)
print("读取数据集完成")

# 绘制训练及测试集迭代曲线
# 记录训练集准确率
train_acc = []
# 记录测试集准确率
test_acc = []
for epoch in range(50):
    print(f"{epoch} 开始训练")
    running_loss = 0.0
    # [(0, data1), (1, data2), (2, data3), ...]
    for i, data in enumerate(train_loader, 0):
        # 获取输入
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()
        # 梯度清零
        optimizer.zero_grad()
        # forward + backward
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # 更新参数
        optimizer.step()
        # 打印log信息
        # loss 是一个scalar,需要使用loss.item()来获取数值，不能使用loss[0]
        running_loss += loss.item()

        print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss))
        running_loss = 0.0
        # 在每次训练完成后，使用测试集进行测试
        correct = 0
        total = 0
        with torch.no_grad():
            for i2, data2 in enumerate(test_loader):
                # 控制测试集的数量
                if i2 > 5:
                    break
                images, labels = data2
                images, labels = images.cuda(), labels.cuda()
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc.append(100 * correct / total)
        print('Accuracy of the network on the test images: %.3f,now max acc is %.3f' % (
            100 * correct / total, max(test_acc)))
        # 保存测试集上准确率最高的模型
        if 100 * correct / total == max(test_acc):
            if not os.path.exists(r'./result'):
                os.makedirs(r'./result')
            if max(test_acc) > 95:
                savename = "./result/bestmodel" + "%.3f" % max(test_acc) + ".pth"
                torch.save(model.state_dict(), savename)

print("最大准确度：", max(test_acc))

# 绘制训练及测试集迭代曲线
plt.plot(range(len(train_acc)), train_acc, color='blue', label='Train Acc')
plt.plot(range(len(test_acc)), test_acc, color='red', label='Test Acc')
plt.legend()
plt.title("Accuracy Curve")
plt.show()


# 测试的脚本
# 在验证集上进行测试
# savename = "./result/bestmodel" + "%.3f" % max(test_acc) + ".pth"
# model.load_state_dict(torch.load(savename))
#
# val_dataset = torchvision.datasets.ImageFolder(r'D:\eyeDataSet\validate', transform=transform2)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=True)
#
# # 验证集上的表现情况：
# correct = 0
# total = 0
# with torch.no_grad():
#     for data in val_loader:
#         images, labels = data
#         images, labels = images.cuda(), labels.cuda()
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
#
# print('Accuracy of the network on the validation images: %.3f %%' % (100 * correct / total))
